{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "from paddle.nn import Linear\n",
    "import paddle.nn.functional as F\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from paddle.vision.transforms import ToTensor  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAEIJJREFUeJzt3X2slGV+xvHvBQoqokI4uohUrLhGqhXNqG3YKLq7qJsq+ocbqfEtu2IalW4KcX1JCmlsYnV3ra5mXVCjtAtbirJKYuuq0VhjaxiUIFapLzkiinAI63tTFX/94zyYI5655zBvz8B9fRIzc57fc8/zOyPXeWbmnplbEYGZ5WdY2Q2YWTkcfrNMOfxmmXL4zTLl8JtlyuE3y5TDv4eS9LKk6WX3Yd1Lnuc3y5PP/GaZcvj3UJJ6JX1P0gJJ/yrpnyV9JOklSd+WdL2kLZLeljRjwLjLJb1S7PumpCt3ut1rJW2S9K6kH0sKSZOL2khJP5O0QdJmSXdL2rfTv7sNjcOfh3OAfwLGAC8Cj9H//34C8HfArwfsuwX4C+AA4HLgNkknAkg6C/gb4HvAZOC0nY7zD8C3galFfQLwt235jaxpfs6/h5LUC/wY+A4wLSK+X2w/B1gKHBgR2yWNBj4ExkTE+4Pczu+ApyLidkn3AZsj4vqiNhl4DTgKeAP4GPjTiHijqP85sCQijmjvb2uN2KvsBqwjNg+4/r/A1ojYPuBngP2B9yWdDcyn/ww+DNgPeKnY51CgOuC23h5wvafYd7WkHdsEDG/R72At5vDbVySNBB4ELgEejojPizP/jjRvAg4bMGTigOtb6f9D8icR8U4n+rXm+Dm/DTQCGAn0AV8UjwJmDKgvAy6XdIyk/RjwfD4ivgQW0f8awcEAkiZIOrNj3dsucfjtKxHxETCH/pD/AfhL4JEB9X8D7gCeAl4H/rMo/V9x+dNi+39J+hB4Aji6I83bLvMLftYwSccA64CREfFF2f3YrvGZ33aJpPMljZA0hv6pvZUO/u7J4bdddSX9rwm8AWwH/qrcdqxRfthvlimf+c0y1dF5/nHjxsWkSZM6eUizrPT29rJ161bV37PJ8Bfv9b6d/ndx3RMRN6f2nzRpEtVqNbWLmTWhUqkMed+GH/ZLGg7cBZwNTAFmSZrS6O2ZWWc185z/ZOD1iHgzIj4DfgvMbE1bZtZuzYR/Al//YMfGYtvXSJotqSqp2tfX18ThzKyVmgn/YC8qfGPeMCIWRkQlIio9PT1NHM7MWqmZ8G/k65/qOgx4t7l2zKxTmgn/KuAoSUdIGgFcyIAPgZhZd2t4qi8ivpB0Nf1fCTUcuC8iXm5ZZ2bWVk3N80fEo8CjLerFzDrIb+81y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNMNbVKr3W/7du3J+sffPBBW49/55131qx9+umnybHr169P1u+6665kfd68eTVrS5cuTY7dZ599kvXrrrsuWZ8/f36y3g2aCr+kXuAjYDvwRURUWtGUmbVfK878p0fE1hbcjpl1kJ/zm2Wq2fAH8HtJqyXNHmwHSbMlVSVV+/r6mjycmbVKs+GfFhEnAmcDV0k6decdImJhRFQiotLT09Pk4cysVZoKf0S8W1xuAVYAJ7eiKTNrv4bDL2mUpNE7rgMzgHWtaszM2quZV/sPAVZI2nE7SyLi31vS1R5mw4YNyfpnn32WrD/33HPJ+rPPPluz9v777yfHLl++PFkv08SJE5P1a665JllfsWJFzdro0aOTY48//vhk/bTTTkvWdwcNhz8i3gTS95CZdS1P9ZllyuE3y5TDb5Yph98sUw6/Wab8kd4WePHFF5P1M844I1lv98dqu9Xw4cOT9ZtuuilZHzVqVLJ+0UUX1awdeuihybFjxoxJ1o8++uhkfXfgM79Zphx+s0w5/GaZcvjNMuXwm2XK4TfLlMNvlinP87fA4YcfnqyPGzcuWe/mef5TTjklWa83H/7UU0/VrI0YMSI59uKLL07WrTk+85tlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmfI8fwuMHTs2Wb/11luT9ZUrVybrJ5xwQrI+Z86cZD1l6tSpyfoTTzyRrNf7TP26dbWXcrjjjjuSY629fOY3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLlef4OOO+885L1et/rX2856bVr19as3XPPPcmx8+bNS9brzePXc+yxx9asLVy4sKnbtubUPfNLuk/SFknrBmwbK+lxSa8Vl+lvdDCzrjOUh/33A2fttO064MmIOAp4svjZzHYjdcMfEc8A23baPBN4oLj+AJB+XGtmXafRF/wOiYhNAMXlwbV2lDRbUlVSta+vr8HDmVmrtf3V/ohYGBGViKj09PS0+3BmNkSNhn+zpPEAxeWW1rVkZp3QaPgfAS4trl8KPNyadsysU+rO80taCkwHxknaCMwHbgaWSfoRsAG4oJ1N7ukOOOCApsYfeOCBDY+t9z6ACy+8MFkfNszvE9td1Q1/RMyqUfpui3sxsw7yn22zTDn8Zply+M0y5fCbZcrhN8uUP9K7B1iwYEHN2urVq5Njn3766WS93ld3z5gxI1m37uUzv1mmHH6zTDn8Zply+M0y5fCbZcrhN8uUw2+WKc/z7wFSX6+9aNGi5NgTTzwxWb/iiiuS9dNPPz1Zr1QqNWtXXXVVcqykZN2a4zO/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Ypz/Pv4Y488shk/f7770/WL7/88mR98eLFDdc/+eST5NhLLrkkWR8/fnyybmk+85tlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmfI8f+bOP//8ZH3y5MnJ+ty5c5P11Pf+X3/99cmxb731VrJ+4403JusTJkxI1nNX98wv6T5JWyStG7BtgaR3JK0p/vtBe9s0s1YbysP++4GzBtl+W0RMLf57tLVtmVm71Q1/RDwDbOtAL2bWQc284He1pLXF04IxtXaSNFtSVVK1r6+vicOZWSs1Gv5fAUcCU4FNwM9r7RgRCyOiEhGVnp6eBg9nZq3WUPgjYnNEbI+IL4FFwMmtbcvM2q2h8Esa+FnK84F1tfY1s+5Ud55f0lJgOjBO0kZgPjBd0lQggF7gyjb2aCU67rjjkvVly5Yl6ytXrqxZu+yyy5Jj77777mT9tddeS9Yff/zxZD13dcMfEbMG2XxvG3oxsw7y23vNMuXwm2XK4TfLlMNvlimH3yxTioiOHaxSqUS1Wu3Y8ay7jRw5Mln//PPPk/W99947WX/sscdq1qZPn54cu7uqVCpUq9UhrW3uM79Zphx+s0w5/GaZcvjNMuXwm2XK4TfLlMNvlil/dbclrV27Nllfvnx5sr5q1aqatXrz+PVMmTIlWT/11FObuv09nc/8Zply+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmPM+/h1u/fn2y/stf/jJZf+ihh5L19957b5d7Gqq99kr/8xw/fnyyPmyYz20pvnfMMuXwm2XK4TfLlMNvlimH3yxTDr9Zphx+s0wNZYnuicBi4FvAl8DCiLhd0ljgX4BJ9C/T/cOI+EP7Ws1Xvbn0JUuW1KzdeeedybG9vb2NtNQSJ510UrJ+4403JuvnnntuK9vJzlDO/F8AcyPiGODPgKskTQGuA56MiKOAJ4ufzWw3UTf8EbEpIl4orn8EvAJMAGYCDxS7PQCc164mzaz1duk5v6RJwAnA88AhEbEJ+v9AAAe3ujkza58hh1/S/sCDwE8i4sNdGDdbUlVSta+vr5EezawNhhR+SXvTH/zfRMSOT3psljS+qI8Htgw2NiIWRkQlIio9PT2t6NnMWqBu+CUJuBd4JSJ+MaD0CHBpcf1S4OHWt2dm7TKUj/ROAy4GXpK0pth2A3AzsEzSj4ANwAXtaXH3t3nz5mT95ZdfTtavvvrqZP3VV1/d5Z5a5ZRTTknWr7322pq1mTNnJsf6I7ntVTf8EfEsUGu97++2th0z6xT/aTXLlMNvlimH3yxTDr9Zphx+s0w5/GaZ8ld3D9G2bdtq1q688srk2DVr1iTrb7zxRkM9tcK0adOS9blz5ybrZ555ZrK+77777nJP1hk+85tlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmcpmnv/5559P1m+55ZZkfdWqVTVrGzdubKinVtlvv/1q1ubMmZMcW+/rsUeNGtVQT9b9fOY3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTKVzTz/ihUrmqo3Y8qUKcn6Oeeck6wPHz48WZ83b17N2kEHHZQca/nymd8sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5QiIr2DNBFYDHwL+BJYGBG3S1oAXAH0FbveEBGPpm6rUqlEtVptumkzG1ylUqFarWoo+w7lTT5fAHMj4gVJo4HVkh4vardFxM8abdTMylM3/BGxCdhUXP9I0ivAhHY3ZmbttUvP+SVNAk4Adnwn1tWS1kq6T9KYGmNmS6pKqvb19Q22i5mVYMjhl7Q/8CDwk4j4EPgVcCQwlf5HBj8fbFxELIyISkRUenp6WtCymbXCkMIvaW/6g/+biHgIICI2R8T2iPgSWASc3L42zazV6oZfkoB7gVci4hcDto8fsNv5wLrWt2dm7TKUV/unARcDL0nasdb0DcAsSVOBAHqB9DrVZtZVhvJq/7PAYPOGyTl9M+tufoefWaYcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y1Tdr+5u6cGkPuCtAZvGAVs71sCu6dbeurUvcG+NamVvh0fEkL4vr6Ph/8bBpWpEVEprIKFbe+vWvsC9Naqs3vyw3yxTDr9ZpsoO/8KSj5/Srb11a1/g3hpVSm+lPuc3s/KUfeY3s5I4/GaZKiX8ks6StF7S65KuK6OHWiT1SnpJ0hpJpa4nXqyBuEXSugHbxkp6XNJrxeWgaySW1NsCSe8U990aST8oqbeJkp6S9IqklyX9dbG91Psu0Vcp91vHn/NLGg78D/B9YCOwCpgVEf/d0UZqkNQLVCKi9DeESDoV+BhYHBHHFttuAbZFxM3FH84xEfHTLultAfBx2cu2F6tJjR+4rDxwHnAZJd53ib5+SAn3Wxln/pOB1yPizYj4DPgtMLOEPrpeRDwDbNtp80zggeL6A/T/4+m4Gr11hYjYFBEvFNc/AnYsK1/qfZfoqxRlhH8C8PaAnzdS4h0wiAB+L2m1pNllNzOIQyJiE/T/YwIOLrmfndVdtr2TdlpWvmvuu0aWu2+1MsI/2NJf3TTfOC0iTgTOBq4qHt7a0Axp2fZOGWRZ+a7Q6HL3rVZG+DcCEwf8fBjwbgl9DCoi3i0utwAr6L6lxzfvWCG5uNxScj9f6aZl2wdbVp4uuO+6abn7MsK/CjhK0hGSRgAXAo+U0Mc3SBpVvBCDpFHADLpv6fFHgEuL65cCD5fYy9d0y7LttZaVp+T7rtuWuy/lHX7FVMY/AsOB+yLi7zvexCAk/TH9Z3voX8F4SZm9SVoKTKf/I5+bgfnA74BlwB8BG4ALIqLjL7zV6G06/Q9dv1q2fcdz7A739h3gP4CXgC+LzTfQ//y6tPsu0dcsSrjf/PZes0z5HX5mmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/Wab+HyKgikXuJBfLAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "图像数据形状和对应数据为 (1, 28, 28)\n",
      "图像数据形状和对应数据为 (1,) [5]\n",
      "\n",
      "输入第一批次的第一个图像，对应的标签数字是[5]\n"
     ]
    }
   ],
   "source": [
    "train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=ToTensor())  \n",
    "train_data0 = np.array(train_dataset[0][0])\n",
    "train_label_0 = np.array(train_dataset[0][1])\n",
    "plt.figure(\"Image\")  \n",
    "plt.imshow(train_data0.squeeze(), cmap=plt.cm.binary)\n",
    "plt.axis('on')  \n",
    "plt.title('image')  \n",
    "plt.show()\n",
    "print(\"图像数据形状和对应数据为\", train_data0.shape)\n",
    "print(\"图像数据形状和对应数据为\", train_label_0.shape, train_label_0)\n",
    "print(\"\\n输入第一批次的第一个图像，对应的标签数字是{}\".format(train_label_0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "from paddle.nn import Linear  \n",
    "import paddle.nn.functional as F  \n",
    "import os  \n",
    "import numpy as np  \n",
    "import matplotlib.pyplot as plt  \n",
    "from paddle.vision.transforms import ToTensor  \n",
    "class MNIST(paddle.nn.Layer):\n",
    "    def __init__(self):\n",
    "        super(MNIST, self).__init__()\n",
    "        self.fc1 = Linear(in_features=784, out_features=100)\n",
    "        self.fc2 = Linear(in_features=100, out_features=100)\n",
    "        self.fc3 = Linear(in_features=100, out_features=10)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        outputs1 = self.fc1(inputs)\n",
    "        outputs1 = F.relu(outputs1)\n",
    "        outputs2 = self.fc2(outputs1)\n",
    "        outputs2 = F.relu(outputs2)\n",
    "        outputs_final = self.fc3(outputs2)\n",
    "        outputs_final = F.softmax(outputs_final)\n",
    "        return outputs_final\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "paddle.vision.set_image_backend('cv2')\n",
    "def norm_img(img):\n",
    "    assert len(img.shape) == 3\n",
    "    batch_size, img_h, img_w = img.shape[0], img.shape[1], img.shape[2]\n",
    "    img = img / 255\n",
    "    img = paddle.reshape(img, [batch_size, img_h * img_w])\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch_id:0, batch_id:0, loss is:[2.304393]\n",
      "epoch_id:0, batch_id:1000, loss is:[2.0591664]\n",
      "epoch_id:0, batch_id:2000, loss is:[1.9079299]\n",
      "epoch_id:0, batch_id:3000, loss is:[1.8845427]\n",
      "epoch_id:1, batch_id:0, loss is:[1.9520712]\n",
      "epoch_id:1, batch_id:1000, loss is:[1.8178248]\n",
      "epoch_id:1, batch_id:2000, loss is:[1.808348]\n",
      "epoch_id:1, batch_id:3000, loss is:[1.6209794]\n",
      "epoch_id:2, batch_id:0, loss is:[1.7841527]\n",
      "epoch_id:2, batch_id:1000, loss is:[1.6413157]\n",
      "epoch_id:2, batch_id:2000, loss is:[1.6855235]\n",
      "epoch_id:2, batch_id:3000, loss is:[1.7203252]\n",
      "epoch_id:3, batch_id:0, loss is:[1.5454392]\n",
      "epoch_id:3, batch_id:1000, loss is:[1.6510302]\n",
      "epoch_id:3, batch_id:2000, loss is:[1.5552189]\n",
      "epoch_id:3, batch_id:3000, loss is:[1.5810901]\n",
      "epoch_id:4, batch_id:0, loss is:[1.7036204]\n",
      "epoch_id:4, batch_id:1000, loss is:[1.6404445]\n",
      "epoch_id:4, batch_id:2000, loss is:[1.5511832]\n",
      "epoch_id:4, batch_id:3000, loss is:[1.5887402]\n"
     ]
    }
   ],
   "source": [
    "model = MNIST()\n",
    "\n",
    "def train(model):\n",
    "    model.train()\n",
    "    train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), batch_size=16, shuffle=True)\n",
    "    opt = paddle.optimizer.SGD(learning_rate=1e-2, parameters=model.parameters())\n",
    "    EPOCH_NUM = 5\n",
    "    for epoch in range(EPOCH_NUM):\n",
    "        for batch_id, data in enumerate(train_loader):\n",
    "            images = norm_img(data[0]).astype('float32')\n",
    "            labels = data[1].astype('int64')\n",
    "\n",
    "            predicts = model(images)\n",
    "\n",
    "            loss = F.cross_entropy(predicts, labels)\n",
    "            avg_loss = paddle.mean(loss)\n",
    "\n",
    "            if batch_id % 1000 == 0:\n",
    "                print(\"epoch_id:{}, batch_id:{}, loss is:{}\".format(epoch, batch_id, avg_loss.numpy()))\n",
    "\n",
    "            avg_loss.backward()\n",
    "            opt.step()\n",
    "            opt.clear_grad()\n",
    "train(model)\n",
    "paddle.save(model.state_dict(), 'mnist.pdparams')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADHpJREFUeJzt3W+oXPWdx/H3d90WwRZUvJpgda9bRFaEtcsQFlwWl2KxS8H0QSV5UCKUxkiFLfSB0SdJHiyYZduuD5aGdM0foU1baP3zQNyKLLiFpThKqHazuxW522YTkytWah9IUb/74J6U23jvzHXmnDmTfN8vkJk5vzNzvpz4uefMnPP7/SIzkVTPH/VdgKR+GH6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0X98Sw3dtVVV+Xi4uIsNymVsrS0xBtvvBEbWXeq8EfEncAjwCXAv2Tmw6PWX1xcZDgcTrNJSSMMBoMNrzvxaX9EXAL8M/BZ4GZge0TcPOnnSZqtab7zbwFezczXMvN3wPeAu9opS1LXpgn/tcCvVr0+2Sz7AxGxMyKGETFcXl6eYnOS2jRN+Nf6UeED/YMz82BmDjJzsLCwMMXmJLVpmvCfBK5b9foTwKnpypE0K9OE/wXgxoi4ISI+CmwDnmqnLEldm/hSX2a+GxH3A//KyqW+Q5n589Yqk9Spqa7zZ+bTwNMt1SJphry9VyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paKmmqU3IpaAt4H3gHczc9BGUZK6N1X4G3+TmW+08DmSZsjTfqmoacOfwI8j4sWI2NlGQZJmY9rT/tsy81REXA08GxH/lZnPr16h+aOwE+D666+fcnOS2jLVkT8zTzWPZ4HHgS1rrHMwMweZOVhYWJhmc5JaNHH4I+KyiPj4uefAZ4BX2ipMUremOe2/Bng8Is59zncz85lWqpLUuYnDn5mvAX/eYi2SZshLfVJRhl8qyvBLRRl+qSjDLxVl+KWi2ujVV94zz4y+veG+++6b6vO3bt06sv2BBx5Yt23Tpk1TbVsXL4/8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1RUZObMNjYYDHI4HM5se7Nyww03jGxfWlqaTSFrWFxcHNk+6h4BgF27drVYjbo2GAwYDoexkXU98ktFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUfbnb8G99947sv3BBx+cUSUfNO4eg3FjDezfv39k+549e0a233PPPeu2HTlyZOR79+3bN7K9y/snxt0fMe7ffPfu3S1W0w2P/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9U1Nj+/BFxCPgccDYzb2mWXQl8H1gEloC7M/PX4zZ2sfbnn9YTTzwxsn3c9e7jx4+3WY5aMMtxMlZruz//EeDO85btBp7LzBuB55rXki4gY8Ofmc8Db563+C7gaPP8KDB6ShlJc2fS7/zXZOZpgObx6vZKkjQLnf/gFxE7I2IYEcPl5eWuNydpgyYN/5mI2AzQPJ5db8XMPJiZg8wcLCwsTLg5SW2bNPxPATua5zuAJ9spR9KsjA1/RBwD/gO4KSJORsSXgIeBOyLiF8AdzWtJF5Cx/fkzc/s6TZ9uuZaytm4dfbFkXPs777yzbtu4/vrj+tRrbeP+TS4E3uEnFWX4paIMv1SU4ZeKMvxSUYZfKsqhuy8Cl1566bpthw8fHvnem266aWR7n8OOX3755SPb33rrrc62vW3btpHt4/brhcAjv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VNXbo7jY5dLc+jM2bN49sf/311yf+7HHX8Y8dOzbxZ/ep7aG7JV2EDL9UlOGXijL8UlGGXyrK8EtFGX6pKPvzqzcHDhwY2T7NdfxxLob++NPyyC8VZfilogy/VJThl4oy/FJRhl8qyvBLRY29zh8Rh4DPAWcz85Zm2V7gy8Bys9pDmfl0V0Xq4rRv375OP39Un/1Rcx1UsZEj/xHgzjWWfzMzb23+M/jSBWZs+DPzeeDNGdQiaYam+c5/f0T8LCIORcQVrVUkaSYmDf+3gE8CtwKnga+vt2JE7IyIYUQMl5eX11tN0oxNFP7MPJOZ72Xm+8C3gS0j1j2YmYPMHCwsLExap6SWTRT+iFg9rOrngVfaKUfSrGzkUt8x4Hbgqog4CewBbo+IW4EEloB7O6xRUgfGhj8zt6+x+NEOalExXfbXB/vsj+MdflJRhl8qyvBLRRl+qSjDLxVl+KWiHLpbnRo3PHeX7LY7mkd+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK6/zqVJfDc48amlvjeeSXijL8UlGGXyrK8EtFGX6pKMMvFWX4paK8zq+pjOuv3+Xw3A7NPR2P/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9U1Njr/BFxHfAYsAl4HziYmY9ExJXA94FFYAm4OzN/3V2pmkd99td3XP7pbOTI/y7wtcz8M+Avga9ExM3AbuC5zLwReK55LekCMTb8mXk6M19qnr8NnACuBe4CjjarHQW2dlWkpPZ9qO/8EbEIfAr4KXBNZp6GlT8QwNVtFyepOxsOf0R8DPgh8NXM/M2HeN/OiBhGxHB5eXmSGiV1YEPhj4iPsBL872Tmj5rFZyJic9O+GTi71nsz82BmDjJzsLCw0EbNklowNvwREcCjwInM/MaqpqeAHc3zHcCT7ZcnqSsb6dJ7G/BF4OWION4sewh4GPhBRHwJ+CXwhW5K1Dyzy+6Fa2z4M/MnQKzT/Ol2y5E0K97hJxVl+KWiDL9UlOGXijL8UlGGXyrKobs1t+yy2y2P/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU/fk10oEDB/ouQR3xyC8VZfilogy/VJThl4oy/FJRhl8qyvBLRY29zh8R1wGPAZuA94GDmflIROwFvgwsN6s+lJlPd1Wo+rF///7OPnvbtm2dfbbG28hNPu8CX8vMlyLi48CLEfFs0/bNzPzH7sqT1JWx4c/M08Dp5vnbEXECuLbrwiR160N954+IReBTwE+bRfdHxM8i4lBEXLHOe3ZGxDAihsvLy2utIqkHGw5/RHwM+CHw1cz8DfAt4JPAraycGXx9rfdl5sHMHGTmYGFhoYWSJbVhQ+GPiI+wEvzvZOaPADLzTGa+l5nvA98GtnRXpqS2jQ1/RATwKHAiM7+xavnmVat9Hnil/fIkdWUjv/bfBnwReDkijjfLHgK2R8StQAJLwL2dVKhe7dixY2T7vn37RrZv3bp13bbDhw9PVJPasZFf+38CxBpNXtOXLmDe4ScVZfilogy/VJThl4oy/FJRhl8qyqG7NdLevXtHtu/atWtk+6ZNm1qsRm3yyC8VZfilogy/VJThl4oy/FJRhl8qyvBLRUVmzm5jEcvA/65adBXwxswK+HDmtbZ5rQusbVJt1vYnmbmh8fJmGv4PbDximJmD3goYYV5rm9e6wNom1VdtnvZLRRl+qai+w3+w5+2PMq+1zWtdYG2T6qW2Xr/zS+pP30d+ST3pJfwRcWdE/HdEvBoRu/uoYT0RsRQRL0fE8YgY9lzLoYg4GxGvrFp2ZUQ8GxG/aB7XnCatp9r2RsT/NfvueET8bU+1XRcR/xYRJyLi5xHxd83yXvfdiLp62W8zP+2PiEuA/wHuAE4CLwDbM/M/Z1rIOiJiCRhkZu/XhCPir4HfAo9l5i3Nsn8A3szMh5s/nFdk5gNzUtte4Ld9z9zcTCizefXM0sBW4B563Hcj6rqbHvZbH0f+LcCrmflaZv4O+B5wVw91zL3MfB5487zFdwFHm+dHWfmfZ+bWqW0uZObpzHypef42cG5m6V733Yi6etFH+K8FfrXq9Unma8rvBH4cES9GxM6+i1nDNc206eemT7+653rON3bm5lk6b2bpudl3k8x43bY+wr/W7D/zdMnhtsz8C+CzwFea01ttzIZmbp6VNWaWnguTznjdtj7CfxK4btXrTwCneqhjTZl5qnk8CzzO/M0+fObcJKnN49me6/m9eZq5ea2ZpZmDfTdPM173Ef4XgBsj4oaI+CiwDXiqhzo+ICIua36IISIuAz7D/M0+/BRwbvbMHcCTPdbyB+Zl5ub1Zpam5303bzNe93KTT3Mp45+AS4BDmfn3My9iDRHxp6wc7WFlZOPv9llbRBwDbmel19cZYA/wBPAD4Hrgl8AXMnPmP7ytU9vtrJy6/n7m5nPfsWdc218B/w68DLzfLH6Ile/Xve27EXVtp4f95h1+UlHe4ScVZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qaj/Bw0QpQndoM8AAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "原始图像形状: (28, 28)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADHpJREFUeJzt3W+oXPWdx/H3d90WwRZUvJpgda9bRFaEtcsQFlwWl2KxS8H0QSV5UCKUxkiFLfSB0SdJHiyYZduuD5aGdM0foU1baP3zQNyKLLiFpThKqHazuxW522YTkytWah9IUb/74J6U23jvzHXmnDmTfN8vkJk5vzNzvpz4uefMnPP7/SIzkVTPH/VdgKR+GH6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0X98Sw3dtVVV+Xi4uIsNymVsrS0xBtvvBEbWXeq8EfEncAjwCXAv2Tmw6PWX1xcZDgcTrNJSSMMBoMNrzvxaX9EXAL8M/BZ4GZge0TcPOnnSZqtab7zbwFezczXMvN3wPeAu9opS1LXpgn/tcCvVr0+2Sz7AxGxMyKGETFcXl6eYnOS2jRN+Nf6UeED/YMz82BmDjJzsLCwMMXmJLVpmvCfBK5b9foTwKnpypE0K9OE/wXgxoi4ISI+CmwDnmqnLEldm/hSX2a+GxH3A//KyqW+Q5n589Yqk9Spqa7zZ+bTwNMt1SJphry9VyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paKmmqU3IpaAt4H3gHczc9BGUZK6N1X4G3+TmW+08DmSZsjTfqmoacOfwI8j4sWI2NlGQZJmY9rT/tsy81REXA08GxH/lZnPr16h+aOwE+D666+fcnOS2jLVkT8zTzWPZ4HHgS1rrHMwMweZOVhYWJhmc5JaNHH4I+KyiPj4uefAZ4BX2ipMUremOe2/Bng8Is59zncz85lWqpLUuYnDn5mvAX/eYi2SZshLfVJRhl8qyvBLRRl+qSjDLxVl+KWi2ujVV94zz4y+veG+++6b6vO3bt06sv2BBx5Yt23Tpk1TbVsXL4/8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1RUZObMNjYYDHI4HM5se7Nyww03jGxfWlqaTSFrWFxcHNk+6h4BgF27drVYjbo2GAwYDoexkXU98ktFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUfbnb8G99947sv3BBx+cUSUfNO4eg3FjDezfv39k+549e0a233PPPeu2HTlyZOR79+3bN7K9y/snxt0fMe7ffPfu3S1W0w2P/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9U1Nj+/BFxCPgccDYzb2mWXQl8H1gEloC7M/PX4zZ2sfbnn9YTTzwxsn3c9e7jx4+3WY5aMMtxMlZruz//EeDO85btBp7LzBuB55rXki4gY8Ofmc8Db563+C7gaPP8KDB6ShlJc2fS7/zXZOZpgObx6vZKkjQLnf/gFxE7I2IYEcPl5eWuNydpgyYN/5mI2AzQPJ5db8XMPJiZg8wcLCwsTLg5SW2bNPxPATua5zuAJ9spR9KsjA1/RBwD/gO4KSJORsSXgIeBOyLiF8AdzWtJF5Cx/fkzc/s6TZ9uuZaytm4dfbFkXPs777yzbtu4/vrj+tRrbeP+TS4E3uEnFWX4paIMv1SU4ZeKMvxSUYZfKsqhuy8Cl1566bpthw8fHvnem266aWR7n8OOX3755SPb33rrrc62vW3btpHt4/brhcAjv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VNXbo7jY5dLc+jM2bN49sf/311yf+7HHX8Y8dOzbxZ/ep7aG7JV2EDL9UlOGXijL8UlGGXyrK8EtFGX6pKPvzqzcHDhwY2T7NdfxxLob++NPyyC8VZfilogy/VJThl4oy/FJRhl8qyvBLRY29zh8Rh4DPAWcz85Zm2V7gy8Bys9pDmfl0V0Xq4rRv375OP39Un/1Rcx1UsZEj/xHgzjWWfzMzb23+M/jSBWZs+DPzeeDNGdQiaYam+c5/f0T8LCIORcQVrVUkaSYmDf+3gE8CtwKnga+vt2JE7IyIYUQMl5eX11tN0oxNFP7MPJOZ72Xm+8C3gS0j1j2YmYPMHCwsLExap6SWTRT+iFg9rOrngVfaKUfSrGzkUt8x4Hbgqog4CewBbo+IW4EEloB7O6xRUgfGhj8zt6+x+NEOalExXfbXB/vsj+MdflJRhl8qyvBLRRl+qSjDLxVl+KWiHLpbnRo3PHeX7LY7mkd+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK6/zqVJfDc48amlvjeeSXijL8UlGGXyrK8EtFGX6pKMMvFWX4paK8zq+pjOuv3+Xw3A7NPR2P/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9U1Njr/BFxHfAYsAl4HziYmY9ExJXA94FFYAm4OzN/3V2pmkd99td3XP7pbOTI/y7wtcz8M+Avga9ExM3AbuC5zLwReK55LekCMTb8mXk6M19qnr8NnACuBe4CjjarHQW2dlWkpPZ9qO/8EbEIfAr4KXBNZp6GlT8QwNVtFyepOxsOf0R8DPgh8NXM/M2HeN/OiBhGxHB5eXmSGiV1YEPhj4iPsBL872Tmj5rFZyJic9O+GTi71nsz82BmDjJzsLCw0EbNklowNvwREcCjwInM/MaqpqeAHc3zHcCT7ZcnqSsb6dJ7G/BF4OWION4sewh4GPhBRHwJ+CXwhW5K1Dyzy+6Fa2z4M/MnQKzT/Ol2y5E0K97hJxVl+KWiDL9UlOGXijL8UlGGXyrKobs1t+yy2y2P/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU/fk10oEDB/ouQR3xyC8VZfilogy/VJThl4oy/FJRhl8qyvBLRY29zh8R1wGPAZuA94GDmflIROwFvgwsN6s+lJlPd1Wo+rF///7OPnvbtm2dfbbG28hNPu8CX8vMlyLi48CLEfFs0/bNzPzH7sqT1JWx4c/M08Dp5vnbEXECuLbrwiR160N954+IReBTwE+bRfdHxM8i4lBEXLHOe3ZGxDAihsvLy2utIqkHGw5/RHwM+CHw1cz8DfAt4JPAraycGXx9rfdl5sHMHGTmYGFhoYWSJbVhQ+GPiI+wEvzvZOaPADLzTGa+l5nvA98GtnRXpqS2jQ1/RATwKHAiM7+xavnmVat9Hnil/fIkdWUjv/bfBnwReDkijjfLHgK2R8StQAJLwL2dVKhe7dixY2T7vn37RrZv3bp13bbDhw9PVJPasZFf+38CxBpNXtOXLmDe4ScVZfilogy/VJThl4oy/FJRhl8qyqG7NdLevXtHtu/atWtk+6ZNm1qsRm3yyC8VZfilogy/VJThl4oy/FJRhl8qyvBLRUVmzm5jEcvA/65adBXwxswK+HDmtbZ5rQusbVJt1vYnmbmh8fJmGv4PbDximJmD3goYYV5rm9e6wNom1VdtnvZLRRl+qai+w3+w5+2PMq+1zWtdYG2T6qW2Xr/zS+pP30d+ST3pJfwRcWdE/HdEvBoRu/uoYT0RsRQRL0fE8YgY9lzLoYg4GxGvrFp2ZUQ8GxG/aB7XnCatp9r2RsT/NfvueET8bU+1XRcR/xYRJyLi5xHxd83yXvfdiLp62W8zP+2PiEuA/wHuAE4CLwDbM/M/Z1rIOiJiCRhkZu/XhCPir4HfAo9l5i3Nsn8A3szMh5s/nFdk5gNzUtte4Ld9z9zcTCizefXM0sBW4B563Hcj6rqbHvZbH0f+LcCrmflaZv4O+B5wVw91zL3MfB5487zFdwFHm+dHWfmfZ+bWqW0uZObpzHypef42cG5m6V733Yi6etFH+K8FfrXq9Unma8rvBH4cES9GxM6+i1nDNc206eemT7+653rON3bm5lk6b2bpudl3k8x43bY+wr/W7D/zdMnhtsz8C+CzwFea01ttzIZmbp6VNWaWnguTznjdtj7CfxK4btXrTwCneqhjTZl5qnk8CzzO/M0+fObcJKnN49me6/m9eZq5ea2ZpZmDfTdPM173Ef4XgBsj4oaI+CiwDXiqhzo+ICIua36IISIuAz7D/M0+/BRwbvbMHcCTPdbyB+Zl5ub1Zpam5303bzNe93KTT3Mp45+AS4BDmfn3My9iDRHxp6wc7WFlZOPv9llbRBwDbmel19cZYA/wBPAD4Hrgl8AXMnPmP7ytU9vtrJy6/n7m5nPfsWdc218B/w68DLzfLH6Ile/Xve27EXVtp4f95h1+UlHe4ScVZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qaj/Bw0QpQndoM8AAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "采样后图像形状: (28, 28)\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "img_path = '0.png'\n",
    "im = Image.open(img_path)\n",
    "plt.imshow(im)\n",
    "plt.show()\n",
    "\n",
    "im = im.convert('L')\n",
    "print('原始图像形状:', np.array(im).shape)\n",
    "\n",
    "# 使用 Image.LANCZOS 方式采样原始图像\n",
    "im = im.resize((28, 28), Image.ANTIALIAS)\n",
    "plt.imshow(im, cmap='gray')  # 显示灰度图像\n",
    "plt.show()\n",
    "print(\"采样后图像形状:\", np.array(im).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "本次预测的数字是:  7\n"
     ]
    }
   ],
   "source": [
    "def load_imafe(img_path):\n",
    "    im = lmage.open(img_path).convert('L')\n",
    "    im = im.resize((28,28), image.ANTIALIAS)\n",
    "    im = np.array(im).reshape(1,-1).astype(np.float32)\n",
    "    im = 1 - im / 255\n",
    "    return im\n",
    "model = MNIST()\n",
    "params_file_path = 'mnist.pdparams'\n",
    "param_dict = paddle.load(params_file_path)\n",
    "model.load_dict(param_dict)\n",
    "model.eval()\n",
    "tensor_img = load_image(img_path)\n",
    "result = model(paddle.to_tensor(tensor_img))\n",
    "print(\"本次预测的数字是: \", np.argsort(result.numpy())[0][-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
